#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import re
import subprocess
import sys
import zipfile
from functools import wraps

import pyredex.unpacker as unpacker
from pyredex.utils import abs_glob, make_temp_dir

# Registry of available sub-commands, keyed by command name. Entries are added
# by the `command` decorator and looked up by `show_usage` and the CLI entry
# point at the bottom of this file.
ALL_COMMANDS = {}

def log(*args):
    """Write the given values to stderr as one space-separated line."""
    message = ' '.join(str(value) for value in args)
    print(message, file=sys.stderr)

def command(name, desc):
    """Build a decorator that registers a function as an apkutil command.

    The decorated function is wrapped so that it is always invoked as
    ``f(dexfiles, apk_dir, *args, **kwds)``. The wrapper carries ``name``
    and ``desc`` attributes, is stored in ``ALL_COMMANDS`` under ``name``,
    and replaces the original function at the definition site.
    """
    def register(func):
        @wraps(func)
        def invoke(dexfiles, apk_dir, *args, **kwds):
            return func(dexfiles, apk_dir, *args, **kwds)

        # Metadata read back when listing and running commands.
        invoke.name, invoke.desc = name, desc
        ALL_COMMANDS[name] = invoke
        return invoke
    return register

# -------------------------- Command Definitions -----------------------------

@command(
    name='dexdump',
    desc=('Calls dexdump with dexes from apk. '
          'Output written to stdout. Accepts all dexdump options.'),
)
def dexdump(dexfiles, apk_dir, *args):
    """Run the external ``dexdump`` tool over the given dex files.

    Extra positional ``args`` are placed on the command line ahead of the
    dex paths. ``apk_dir`` is part of the common command signature and is
    not used here. ``subprocess.check_call`` raises ``CalledProcessError``
    when the tool exits with a non-zero status.
    """
    argv = ['dexdump']
    argv.extend(args)
    subprocess.check_call(argv + dexfiles)


@command(
    name='redexdump',
    desc=('Calls redexdump with dexes from apk. '
          'Output written to stdout. Accepts all redexdump options.'),
)
def redexdump(dexfiles, apk_dir, *args):
    """Run the external ``redexdump`` tool over the given dex files.

    Extra positional ``args`` are placed on the command line ahead of the
    dex paths. ``apk_dir`` is part of the common command signature and is
    not used here. ``subprocess.check_call`` raises ``CalledProcessError``
    when the tool exits with a non-zero status.
    """
    tool_options = list(args)
    subprocess.check_call(['redexdump'] + tool_options + dexfiles)


@command(name='codesize',
         desc='Prints total and primary dex size')
def codesize(dexfiles, apk_dir, *args):
    """Report the combined size of all dex files and of the primary dex.

    Sizes are reported in MB (1024 * 1024 bytes) through ``log``.

    Args:
        dexfiles: Paths of the dex files to measure.
        apk_dir: Part of the common command signature; unused here.
        *args: Ignored.
    """
    one_mb = 1024. * 1024.
    total_size = sum(os.path.getsize(d) for d in dexfiles)
    # Compare the file name exactly: a suffix test on the whole path would
    # also count files such as "extra_classes.dex" as the primary dex.
    primary_dex_size = sum(
        os.path.getsize(d) for d in dexfiles
        if os.path.basename(d) == 'classes.dex')

    dex_count = len(dexfiles)
    log('Total code size: {:.2f}MB in {} dex files'.format(
            total_size / one_mb, dex_count))
    log('Primary dex size: {:.2f}MB'.format(primary_dex_size / one_mb))


@command(name='classes',
         desc='Prints list of classes. Use -j for java-style names.')
def classes(dexfiles, apk_dir, *args):
    """Print the descriptor of every class that dexdump reports.

    Runs ``dexdump -f`` over the dex files and prints one class per line to
    stdout, taken from the "Class descriptor" lines of its output.

    Args:
        dexfiles: List of dex file paths handed to dexdump.
        apk_dir: Part of the common command signature; unused here.
        *args: If ``'-j'`` is present, names are printed java-style
            (first and last character stripped, ``/`` replaced by ``.``).

    Raises:
        subprocess.CalledProcessError: dexdump exited with a non-zero
            status (matching the commands that use ``check_call``).
    """
    java_style = '-j' in args
    pattern = re.compile(b"  Class descriptor  : '([^']+)'")
    dexdump_args = ['dexdump', '-f'] + dexfiles

    # The context manager closes the pipe and waits for dexdump to exit, so
    # the child process is always reaped, even if printing fails part-way.
    with subprocess.Popen(dexdump_args, stdout=subprocess.PIPE) as proc:
        for line in proc.stdout:
            match = pattern.match(line)
            if match is None:
                continue
            classname = match.group(1).decode()
            if java_style:
                # e.g. "Lcom/foo/Bar;" -> "com.foo.Bar"
                classname = classname[1:-1].replace('/', '.')
            print(classname)

    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, dexdump_args)


# ------------------------- End Command Definitions ---------------------------

def invoke_command(apk_path, cmd, *command_args):
    """Unpack an APK and run one registered command against its dex files.

    The APK is extracted into a temporary directory, its dex files are
    unpacked into a second temporary directory using the dex mode detected
    by ``unpacker``, and ``cmd`` is then called with the resulting ``*.dex``
    paths.

    Args:
        apk_path: Path to the APK (zip) file to inspect.
        cmd: A command registered via ``@command``; it is called as
            ``cmd(dex_files, apk_dir, *command_args)``.
        *command_args: Extra arguments forwarded to the command unchanged.
    """
    extracted_apk_dir = make_temp_dir('.redex_extracted_apk')

    log('Extracting apk...')
    with zipfile.ZipFile(apk_path) as z:
        z.extractall(extracted_apk_dir)

    dex_mode = unpacker.detect_secondary_dex_mode(extracted_apk_dir)
    log('Detected dex mode ' + type(dex_mode).__name__)
    dex_dir = make_temp_dir('.redex_dexen')

    log('Unpacking dex files')
    dex_mode.unpackage(extracted_apk_dir, dex_dir)

    dex_files = list(abs_glob(dex_dir, '*.dex'))

    log('Running command ' + cmd.name + '...')
    # Commands declare their second parameter as `apk_dir`, so hand them the
    # directory the APK was extracted into rather than the APK file path.
    cmd(dex_files, extracted_apk_dir, *command_args)


def show_usage():
    """Log the usage banner and the list of registered commands."""
    log('Redex APK Utilities')
    log('Usage: apkutil <path_to_apk> <command> [command args...]\n')
    log('Available commands:')

    # Sort by name only, so the listing order never depends on comparing
    # the command functions themselves.
    for name in sorted(ALL_COMMANDS):
        log('    {:<12} {}'.format(name, ALL_COMMANDS[name].desc))


# Command-line entry point: apkutil <path_to_apk> <command> [command args...]
if __name__ == '__main__':
    if len(sys.argv) < 3:
        show_usage()
        sys.exit(1)

    apk_path = sys.argv[1]
    command_name = sys.argv[2]

    # Stop here instead of letting the zip extraction fail on a bad path.
    if not os.path.isfile(apk_path):
        log(apk_path + ' is not a file')
        sys.exit(1)

    # An unknown command used to be ignored silently; report it instead.
    if command_name not in ALL_COMMANDS:
        log('Unknown command: ' + command_name)
        show_usage()
        sys.exit(1)

    invoke_command(apk_path, ALL_COMMANDS[command_name], *sys.argv[3:])
